// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorScanSycl.h
 *
 * \brief:
 *  Tensor Scan Sycl implements an extended version of
 * "Efficient parallel scan algorithms for GPUs." for Tensor operations.
 * The algorithm requires up to 3 stages (and consequently up to 3 kernels)
 * depending on the size of the tensor. In the first kernel
 * (ScanKernelFunctor), each thread within a work-group individually scans the
 * elements allocated to it, in order to reduce the total number of blocks.
 * In the next step all threads within the work-group reduce the associated
 * blocks into the temporary buffer. In the next kernel (ScanKernelFunctor run
 * as the second step), the temporary buffer is given as an input and all the
 * threads within a work-group scan and reduce the boundaries between the
 * blocks (generated from the previous kernel) and write the data back to the
 * temporary buffer. If the second kernel is required, the third and final
 * kernel (ScanAdjustmentKernelFunctor) will
 * adjust the final result into the output buffer.
 * The original algorithm for the parallel prefix sum can be found here:
 *
 * Sengupta, Shubhabrata, Mark Harris, and Michael Garland. "Efficient parallel
 * scan algorithms for GPUs." NVIDIA, Santa Clara, CA, Tech. Rep. NVR-2008-003
 *1, no. 1 (2008): 1-17.
 *****************************************************************/

#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP
#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP

namespace Eigen {
namespace TensorSycl {
namespace internal {

#ifndef EIGEN_SYCL_MAX_GLOBAL_RANGE
#define EIGEN_SYCL_MAX_GLOBAL_RANGE (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 * 4)
#endif

// Immutable bundle of sizes, strides and thread counts that describes one
// scan launch; computed by ScanInfo and consumed by the device kernels.
template<typename index_t>
struct ScanParameters
{
	// Number of elements each thread scans privately; must be a power of 2.
	static EIGEN_CONSTEXPR index_t ScanPerThread = 8;

	const index_t total_size;         // total number of elements in the tensor
	const index_t non_scan_size;      // product of the non-scanned dimensions
	const index_t scan_size;          // length of the scanned dimension
	const index_t non_scan_stride;    // stride between consecutive non-scan elements
	const index_t scan_stride;        // stride between consecutive scan elements
	const index_t panel_threads;      // threads assigned to one panel
	const index_t group_threads;      // threads assigned to one group
	const index_t block_threads;      // threads assigned to one block
	const index_t elements_per_group; // elements handled by one group
	const index_t elements_per_block; // elements handled by one block
	const index_t loop_range;         // grid-sweep iterations each thread performs

	ScanParameters(index_t total_sz,
				   index_t non_scan_sz,
				   index_t scan_sz,
				   index_t non_scan_str,
				   index_t scan_str,
				   index_t panel_thr,
				   index_t group_thr,
				   index_t block_thr,
				   index_t elems_per_group,
				   index_t elems_per_block,
				   index_t loop_rng)
		: total_size(total_sz)
		, non_scan_size(non_scan_sz)
		, scan_size(scan_sz)
		, non_scan_stride(non_scan_str)
		, scan_stride(scan_str)
		, panel_threads(panel_thr)
		, group_threads(group_thr)
		, block_threads(block_thr)
		, elements_per_group(elems_per_group)
		, elements_per_block(elems_per_block)
		, loop_range(loop_rng)
	{
	}
};

// Selects which pass ScanKernelFunctor is performing: `first` scans the input
// tensor through its evaluator, `second` scans the temporary buffer of
// per-block totals (see ScanLauncher_impl::scan_block).
enum class scan_step
{
	first,
	second
};
// Scan kernel used for both stage 1 (stp == scan_step::first, reading from the
// tensor evaluator) and stage 2 (stp == scan_step::second, reading from the
// temporary buffer). Each thread privately scans ScanPerThread elements as two
// packets via an up-sweep/down-sweep tree; the work-group then cooperatively
// scans the per-packet totals staged in local memory, and — when a group spans
// several blocks — exports each block's total to `temp_accessor` for the next
// stage.
template<typename Evaluator, typename CoeffReturnType, typename OutAccessor, typename Op, typename Index, scan_step stp>
struct ScanKernelFunctor
{
	typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
		LocalAccessor;
	// A thread's ScanPerThread elements are processed as two packets of this size.
	static EIGEN_CONSTEXPR int PacketSize = ScanParameters<Index>::ScanPerThread / 2;

	LocalAccessor scratch;                      // work-group local memory for the cooperative scan
	Evaluator dev_eval;                         // input: evaluator (stage 1) or accessor (stage 2)
	OutAccessor out_accessor;                   // scanned output
	OutAccessor temp_accessor;                  // per-block totals handed to the next stage
	const ScanParameters<Index> scanParameters; // launch geometry (sizes, strides, thread counts)
	Op accumulator;                             // the scan operator
	const bool inclusive;                       // inclusive vs. exclusive scan (stage 1 only)
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScanKernelFunctor(LocalAccessor scratch_,
															const Evaluator dev_eval_,
															OutAccessor out_accessor_,
															OutAccessor temp_accessor_,
															const ScanParameters<Index> scanParameters_,
															Op accumulator_,
															const bool inclusive_)
		: scratch(scratch_)
		, dev_eval(dev_eval_)
		, out_accessor(out_accessor_)
		, temp_accessor(temp_accessor_)
		, scanParameters(scanParameters_)
		, accumulator(accumulator_)
		, inclusive(inclusive_)
	{
	}

	// Reads one input element. Stage-1 overload: go through the evaluator's
	// coeff() so the input expression is evaluated on the fly.
	template<scan_step sst = stp, typename Input>
	typename ::Eigen::internal::enable_if<sst == scan_step::first, CoeffReturnType>::type EIGEN_DEVICE_FUNC
		EIGEN_STRONG_INLINE
		read(const Input& inpt, Index global_id)
	{
		return inpt.coeff(global_id);
	}

	// Stage-2 overload: the input is the temporary buffer, indexed directly.
	template<scan_step sst = stp, typename Input>
	typename ::Eigen::internal::enable_if<sst != scan_step::first, CoeffReturnType>::type EIGEN_DEVICE_FUNC
		EIGEN_STRONG_INLINE
		read(const Input& inpt, Index global_id)
	{
		return inpt[global_id];
	}

	// Runs `inclusive_op` only during stage 1; the inclusive adjustment is not
	// needed when scanning the per-block totals in stage 2.
	template<scan_step sst = stp, typename InclusiveOp>
	typename ::Eigen::internal::enable_if<sst == scan_step::first>::type EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
	first_step_inclusive_Operation(InclusiveOp inclusive_op)
	{
		inclusive_op();
	}

	// Stage-2 overload: no-op.
	template<scan_step sst = stp, typename InclusiveOp>
	typename ::Eigen::internal::enable_if<sst != scan_step::first>::type EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
	first_step_inclusive_Operation(InclusiveOp)
	{
	}

	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID)
	{
		auto out_ptr = out_accessor.get_pointer();
		auto tmp_ptr = temp_accessor.get_pointer();
		auto scratch_ptr = scratch.get_pointer().get();

		for (Index loop_offset = 0; loop_offset < scanParameters.loop_range; loop_offset++) {
			// Decompose the flat (grid-swept) thread id into panel/group/block/local
			// coordinates.
			Index data_offset = (itemID.get_global_id(0) + (itemID.get_global_range(0) * loop_offset));
			Index tmp = data_offset % scanParameters.panel_threads;
			const Index panel_id = data_offset / scanParameters.panel_threads;
			const Index group_id = tmp / scanParameters.group_threads;
			tmp = tmp % scanParameters.group_threads;
			const Index block_id = tmp / scanParameters.block_threads;
			const Index local_id = tmp % scanParameters.block_threads;
			// we put one element per packet in scratch_mem
			const Index scratch_stride = scanParameters.elements_per_block / PacketSize;
			const Index scratch_offset = (itemID.get_local_id(0) / scanParameters.block_threads) * scratch_stride;
			CoeffReturnType private_scan[ScanParameters<Index>::ScanPerThread];
			CoeffReturnType inclusive_scan;
			// the actual panel size is scan_size * non_scan_size.
			// elements_per_panel is roundup to power of 2 for binary tree
			const Index panel_offset = panel_id * scanParameters.scan_size * scanParameters.non_scan_size;
			const Index group_offset = group_id * scanParameters.non_scan_stride;
			// This will be effective when the size is bigger than elements_per_block
			const Index block_offset = block_id * scanParameters.elements_per_block * scanParameters.scan_stride;
			const Index thread_offset = (ScanParameters<Index>::ScanPerThread * local_id * scanParameters.scan_stride);
			const Index global_offset = panel_offset + group_offset + block_offset + thread_offset;
			Index next_elements = 0;
			// Load this thread's elements; out-of-range slots are padded with the
			// accumulator's identity so the power-of-two tree stays well formed.
			EIGEN_UNROLL_LOOP
			for (int i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
				Index global_id = global_offset + next_elements;
				private_scan[i] =
					((((block_id * scanParameters.elements_per_block) +
					   (ScanParameters<Index>::ScanPerThread * local_id) + i) < scanParameters.scan_size) &&
					 (global_id < scanParameters.total_size))
						? read(dev_eval, global_id)
						: accumulator.initialize();
				next_elements += scanParameters.scan_stride;
			}
			// For an inclusive scan, remember the last loaded element: exclusive
			// result + this element gives the inclusive result (folded in below).
			first_step_inclusive_Operation([&]() EIGEN_DEVICE_FUNC {
				if (inclusive) {
					inclusive_scan = private_scan[ScanParameters<Index>::ScanPerThread - 1];
				}
			});
			// This for loop must be 2
			EIGEN_UNROLL_LOOP
			for (int packetIndex = 0; packetIndex < ScanParameters<Index>::ScanPerThread; packetIndex += PacketSize) {
				Index private_offset = 1;
				// build sum in place up the tree
				EIGEN_UNROLL_LOOP
				for (Index d = PacketSize >> 1; d > 0; d >>= 1) {
					EIGEN_UNROLL_LOOP
					for (Index l = 0; l < d; l++) {
						Index ai = private_offset * (2 * l + 1) - 1 + packetIndex;
						Index bi = private_offset * (2 * l + 2) - 1 + packetIndex;
						CoeffReturnType accum = accumulator.initialize();
						accumulator.reduce(private_scan[ai], &accum);
						accumulator.reduce(private_scan[bi], &accum);
						private_scan[bi] = accumulator.finalize(accum);
					}
					private_offset *= 2;
				}
				// Stage this packet's total in local memory, then clear it so the
				// down-sweep yields an exclusive scan of the packet.
				scratch_ptr[2 * local_id + (packetIndex / PacketSize) + scratch_offset] =
					private_scan[PacketSize - 1 + packetIndex];
				private_scan[PacketSize - 1 + packetIndex] = accumulator.initialize();
				// traverse down tree & build scan
				EIGEN_UNROLL_LOOP
				for (Index d = 1; d < PacketSize; d *= 2) {
					private_offset >>= 1;
					EIGEN_UNROLL_LOOP
					for (Index l = 0; l < d; l++) {
						Index ai = private_offset * (2 * l + 1) - 1 + packetIndex;
						Index bi = private_offset * (2 * l + 2) - 1 + packetIndex;
						CoeffReturnType accum = accumulator.initialize();
						accumulator.reduce(private_scan[ai], &accum);
						accumulator.reduce(private_scan[bi], &accum);
						private_scan[ai] = private_scan[bi];
						private_scan[bi] = accumulator.finalize(accum);
					}
				}
			}

			// Cooperative work-group scan over the per-packet totals in local memory.
			Index offset = 1;
			// build sum in place up the tree
			for (Index d = scratch_stride >> 1; d > 0; d >>= 1) {
				// Synchronise
				itemID.barrier(cl::sycl::access::fence_space::local_space);
				if (local_id < d) {
					Index ai = offset * (2 * local_id + 1) - 1 + scratch_offset;
					Index bi = offset * (2 * local_id + 2) - 1 + scratch_offset;
					CoeffReturnType accum = accumulator.initialize();
					accumulator.reduce(scratch_ptr[ai], &accum);
					accumulator.reduce(scratch_ptr[bi], &accum);
					scratch_ptr[bi] = accumulator.finalize(accum);
				}
				offset *= 2;
			}
			// Synchronise
			itemID.barrier(cl::sycl::access::fence_space::local_space);
			// next step optimisation
			if (local_id == 0) {
				// When the group spans more than one block, export this block's
				// grand total so the next stage can scan across blocks.
				if (((scanParameters.elements_per_group / scanParameters.elements_per_block) > 1)) {
					const Index temp_id =
						panel_id * (scanParameters.elements_per_group / scanParameters.elements_per_block) *
							scanParameters.non_scan_size +
						group_id * (scanParameters.elements_per_group / scanParameters.elements_per_block) + block_id;
					tmp_ptr[temp_id] = scratch_ptr[scratch_stride - 1 + scratch_offset];
				}
				// clear the last element
				scratch_ptr[scratch_stride - 1 + scratch_offset] = accumulator.initialize();
			}
			// traverse down tree & build scan
			for (Index d = 1; d < scratch_stride; d *= 2) {
				offset >>= 1;
				// Synchronise
				itemID.barrier(cl::sycl::access::fence_space::local_space);
				if (local_id < d) {
					Index ai = offset * (2 * local_id + 1) - 1 + scratch_offset;
					Index bi = offset * (2 * local_id + 2) - 1 + scratch_offset;
					CoeffReturnType accum = accumulator.initialize();
					accumulator.reduce(scratch_ptr[ai], &accum);
					accumulator.reduce(scratch_ptr[bi], &accum);
					scratch_ptr[ai] = scratch_ptr[bi];
					scratch_ptr[bi] = accumulator.finalize(accum);
				}
			}
			// Synchronise
			itemID.barrier(cl::sycl::access::fence_space::local_space);
			// Add each thread's exclusive prefix (now in local memory) onto its
			// private packet scans.
			// This for loop must be 2
			EIGEN_UNROLL_LOOP
			for (int packetIndex = 0; packetIndex < ScanParameters<Index>::ScanPerThread; packetIndex += PacketSize) {
				EIGEN_UNROLL_LOOP
				for (Index i = 0; i < PacketSize; i++) {
					CoeffReturnType accum = private_scan[packetIndex + i];
					accumulator.reduce(scratch_ptr[2 * local_id + (packetIndex / PacketSize) + scratch_offset], &accum);
					private_scan[packetIndex + i] = accumulator.finalize(accum);
				}
			}
			// For inclusive scans fold the saved last element into the final
			// prefix and store it in slot 0; the write-out below rotates by one.
			first_step_inclusive_Operation([&]() EIGEN_DEVICE_FUNC {
				if (inclusive) {
					accumulator.reduce(private_scan[ScanParameters<Index>::ScanPerThread - 1], &inclusive_scan);
					private_scan[0] = accumulator.finalize(inclusive_scan);
				}
			});
			next_elements = 0;
			// Write the scanned values back; for an inclusive scan element i's
			// result lives in slot (i + 1) % ScanPerThread (the rotation above).
			EIGEN_UNROLL_LOOP
			for (Index i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
				Index global_id = global_offset + next_elements;
				if ((((block_id * scanParameters.elements_per_block) +
					  (ScanParameters<Index>::ScanPerThread * local_id) + i) < scanParameters.scan_size) &&
					(global_id < scanParameters.total_size)) {
					Index private_id =
						(i * !inclusive) + (((i + 1) % ScanParameters<Index>::ScanPerThread) * (inclusive));
					out_ptr[global_id] = private_scan[private_id];
				}
				next_elements += scanParameters.scan_stride;
			}
		} // end for loop
	}
};

// Final adjustment kernel: folds the scanned per-block totals (read from
// `in_accessor`) into the block-local scan results already stored in
// `out_accessor`, producing the globally correct scan.
template<typename CoeffReturnType, typename InAccessor, typename OutAccessor, typename Op, typename Index>
struct ScanAdjustmentKernelFunctor
{
	typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
		LocalAccessor;
	static EIGEN_CONSTEXPR int PacketSize = ScanParameters<Index>::ScanPerThread / 2;
	InAccessor in_accessor;                     // scanned per-block totals
	OutAccessor out_accessor;                   // block-local scan results, adjusted in place
	const ScanParameters<Index> scanParameters; // launch geometry
	Op accumulator;                             // the scan operator
	// The LocalAccessor argument is required by the launcher interface but not
	// used by this kernel.
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScanAdjustmentKernelFunctor(LocalAccessor,
																	  InAccessor in_acc,
																	  OutAccessor out_acc,
																	  const ScanParameters<Index> params,
																	  Op op)
		: in_accessor(in_acc)
		, out_accessor(out_acc)
		, scanParameters(params)
		, accumulator(op)
	{
	}

	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID)
	{
		auto totals_ptr = in_accessor.get_pointer();
		auto result_ptr = out_accessor.get_pointer();

		for (Index sweep = 0; sweep < scanParameters.loop_range; sweep++) {
			// Decompose the flat (grid-swept) thread id into panel/group/block/local
			// coordinates, mirroring ScanKernelFunctor.
			const Index flat_id = itemID.get_global_id(0) + (itemID.get_global_range(0) * sweep);
			const Index panel_id = flat_id / scanParameters.panel_threads;
			Index rem = flat_id % scanParameters.panel_threads;
			const Index group_id = rem / scanParameters.group_threads;
			rem = rem % scanParameters.group_threads;
			const Index block_id = rem / scanParameters.block_threads;
			const Index local_id = rem % scanParameters.block_threads;

			// the actual panel size is scan_size * non_scan_size.
			// elements_per_panel is roundup to power of 2 for binary tree
			const Index panel_offset = panel_id * scanParameters.scan_size * scanParameters.non_scan_size;
			const Index group_offset = group_id * scanParameters.non_scan_stride;
			// This will be effective when the size is bigger than elements_per_block
			const Index block_offset = block_id * scanParameters.elements_per_block * scanParameters.scan_stride;
			const Index thread_offset = ScanParameters<Index>::ScanPerThread * local_id * scanParameters.scan_stride;

			const Index global_offset = panel_offset + group_offset + block_offset + thread_offset;
			// Locate this block's scanned total in the temporary buffer.
			const Index block_size = scanParameters.elements_per_group / scanParameters.elements_per_block;
			const Index in_id =
				(panel_id * block_size * scanParameters.non_scan_size) + (group_id * block_size) + block_id;
			const CoeffReturnType adjust_val = totals_ptr[in_id];

			// Fold the block offset into every in-range element owned by this thread.
			Index stride_acc = 0;
			EIGEN_UNROLL_LOOP
			for (Index i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
				const Index global_id = global_offset + stride_acc;
				const bool within_scan = ((block_id * scanParameters.elements_per_block) +
										  (ScanParameters<Index>::ScanPerThread * local_id) + i) <
										 scanParameters.scan_size;
				if (within_scan && (global_id < scanParameters.total_size)) {
					CoeffReturnType accum = adjust_val;
					accumulator.reduce(result_ptr[global_id], &accum);
					result_ptr[global_id] = accumulator.finalize(accum);
				}
				stride_acc += scanParameters.scan_stride;
			}
		}
	}
};

// Derives the launch geometry for the scan kernels — thread counts, per-block
// and per-group element counts, global/local ranges and the grid-sweep loop
// count — from the tensor's sizes and the device's work-group capabilities.
template<typename Index>
struct ScanInfo
{
	const Index& total_size;
	const Index& scan_size;
	const Index& panel_size;
	const Index& non_scan_size;
	const Index& scan_stride;
	const Index& non_scan_stride;

	Index max_elements_per_block;
	Index block_size;
	Index panel_threads;
	Index group_threads;
	Index block_threads;
	Index elements_per_group;
	Index elements_per_block;
	Index loop_range;
	Index global_range;
	Index local_range;
	const Eigen::SyclDevice& dev;
	EIGEN_STRONG_INLINE ScanInfo(const Index& total_size_,
								 const Index& scan_size_,
								 const Index& panel_size_,
								 const Index& non_scan_size_,
								 const Index& scan_stride_,
								 const Index& non_scan_stride_,
								 const Eigen::SyclDevice& dev_)
		: total_size(total_size_)
		, scan_size(scan_size_)
		, panel_size(panel_size_)
		, non_scan_size(non_scan_size_)
		, scan_stride(scan_stride_)
		, non_scan_stride(non_scan_stride_)
		, dev(dev_)
	{
		// must be power of 2
		local_range = std::min(Index(dev.getNearestPowerOfTwoWorkGroupSize()),
							   Index(EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1));

		// Upper bound on the number of elements one work-group can scan.
		max_elements_per_block = local_range * ScanParameters<Index>::ScanPerThread;

		// Round the scan dimension up to a power of two so the binary-tree scan
		// is well formed (out-of-range slots are padded in the kernel).
		elements_per_group =
			dev.getPowerOfTwo(Index(roundUp(Index(scan_size), ScanParameters<Index>::ScanPerThread)), true);
		const Index elements_per_panel = elements_per_group * non_scan_size;
		elements_per_block = std::min(Index(elements_per_group), Index(max_elements_per_block));
		panel_threads = elements_per_panel / ScanParameters<Index>::ScanPerThread;
		group_threads = elements_per_group / ScanParameters<Index>::ScanPerThread;
		block_threads = elements_per_block / ScanParameters<Index>::ScanPerThread;
		// Number of blocks per group; > 1 triggers the second and third kernels.
		block_size = elements_per_group / elements_per_block;
#ifdef EIGEN_SYCL_MAX_GLOBAL_RANGE
		// Cap the launch size; each thread then sweeps the data loop_range times.
		const Index max_threads = std::min(Index(panel_threads * panel_size), Index(EIGEN_SYCL_MAX_GLOBAL_RANGE));
#else
		const Index max_threads = panel_threads * panel_size;
#endif
		global_range = roundUp(max_threads, local_range);
		loop_range = Index(
			std::ceil(double(elements_per_panel * panel_size) / (global_range * ScanParameters<Index>::ScanPerThread)));
	}
	// Packs the computed geometry into the POD consumed by the device kernels.
	inline ScanParameters<Index> get_scan_parameter()
	{
		return ScanParameters<Index>(total_size,
									 non_scan_size,
									 scan_size,
									 non_scan_stride,
									 scan_stride,
									 panel_threads,
									 group_threads,
									 block_threads,
									 elements_per_group,
									 elements_per_block,
									 loop_range);
	}
	// The nd_range (global x local) used to launch the kernels.
	inline cl::sycl::nd_range<1> get_thread_range()
	{
		return cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
	}
};

// Thin launcher for ScanAdjustmentKernelFunctor: recomputes the launch
// geometry and folds the scanned per-block totals (`in_ptr`) into the
// partially scanned output (`out_ptr`).
template<typename EvaluatorPointerType, typename CoeffReturnType, typename Reducer, typename Index>
struct SYCLAdjustBlockOffset
{
	EIGEN_STRONG_INLINE static void adjust_scan_block_offset(EvaluatorPointerType in_ptr,
															 EvaluatorPointerType out_ptr,
															 Reducer& accumulator,
															 const Index total_size,
															 const Index scan_size,
															 const Index panel_size,
															 const Index non_scan_size,
															 const Index scan_stride,
															 const Index non_scan_stride,
															 const Eigen::SyclDevice& dev)
	{
		typedef ScanAdjustmentKernelFunctor<CoeffReturnType, EvaluatorPointerType, EvaluatorPointerType, Reducer, Index>
			AdjustFuctor;

		ScanInfo<Index> info(total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride, dev);
		dev.template unary_kernel_launcher<CoeffReturnType, AdjustFuctor>(
			in_ptr, out_ptr, info.get_thread_range(), info.max_elements_per_block, info.get_scan_parameter(),
			accumulator);
	}
};

// Host-side launcher. Runs ScanKernelFunctor over the input; when the data
// needs more than one block per group, recursively scans the per-block totals
// (scan_step::second) and then corrects the output with
// ScanAdjustmentKernelFunctor.
template<typename CoeffReturnType, scan_step stp>
struct ScanLauncher_impl
{
	template<typename Input, typename EvaluatorPointerType, typename Reducer, typename Index>
	EIGEN_STRONG_INLINE static void scan_block(Input in_ptr,
											   EvaluatorPointerType out_ptr,
											   Reducer& accumulator,
											   const Index total_size,
											   const Index scan_size,
											   const Index panel_size,
											   const Index non_scan_size,
											   const Index scan_stride,
											   const Index non_scan_stride,
											   const bool inclusive,
											   const Eigen::SyclDevice& dev)
	{
		auto scan_info =
			ScanInfo<Index>(total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride, dev);
		// One temporary slot per block: holds each block's total for stage 2.
		const Index temp_pointer_size = scan_info.block_size * non_scan_size * panel_size;
		// Local-memory elements per work-group (one slot per packet).
		const Index scratch_size = scan_info.max_elements_per_block / (ScanParameters<Index>::ScanPerThread / 2);
		CoeffReturnType* temp_pointer =
			static_cast<CoeffReturnType*>(dev.allocate_temp(temp_pointer_size * sizeof(CoeffReturnType)));
		EvaluatorPointerType tmp_global_accessor = dev.get(temp_pointer);

		typedef ScanKernelFunctor<Input, CoeffReturnType, EvaluatorPointerType, Reducer, Index, stp> ScanFunctor;
		dev.template binary_kernel_launcher<CoeffReturnType, ScanFunctor>(in_ptr,
																		  out_ptr,
																		  tmp_global_accessor,
																		  scan_info.get_thread_range(),
																		  scratch_size,
																		  scan_info.get_scan_parameter(),
																		  accumulator,
																		  inclusive);

		if (scan_info.block_size > 1) {
			// Exclusive, unit-stride scan of the per-block totals in place ...
			ScanLauncher_impl<CoeffReturnType, scan_step::second>::scan_block(tmp_global_accessor,
																			  tmp_global_accessor,
																			  accumulator,
																			  temp_pointer_size,
																			  scan_info.block_size,
																			  panel_size,
																			  non_scan_size,
																			  Index(1),
																			  scan_info.block_size,
																			  false,
																			  dev);

			// ... then fold each block's scanned total into its output elements.
			SYCLAdjustBlockOffset<EvaluatorPointerType, CoeffReturnType, Reducer, Index>::adjust_scan_block_offset(
				tmp_global_accessor,
				out_ptr,
				accumulator,
				total_size,
				scan_size,
				panel_size,
				non_scan_size,
				scan_stride,
				non_scan_stride,
				dev);
		}
		dev.deallocate_temp(temp_pointer);
	}
};

} // namespace internal
} // namespace TensorSycl
namespace internal {
// SYCL specialization of the TensorScanOp launcher: derives the scan geometry
// (scan / non-scan / panel sizes and strides) from the tensor layout and
// delegates to the first-stage ScanLauncher_impl.
template<typename Self, typename Reducer, bool vectorize>
struct ScanLauncher<Self, Reducer, Eigen::SyclDevice, vectorize>
{
	typedef typename Self::Index Index;
	typedef typename Self::CoeffReturnType CoeffReturnType;
	typedef typename Self::Storage Storage;
	typedef typename Self::EvaluatorPointerType EvaluatorPointerType;
	void operator()(Self& self, EvaluatorPointerType data)
	{
		const Index total_size = internal::array_prod(self.dimensions());
		const Index scan_size = self.size();
		const Index scan_stride = self.stride();
		// this is the scan op (can be sum or ...)
		auto accumulator = self.accumulator();
		auto inclusive = !self.exclusive();
		auto consume_dim = self.consume_dim();
		auto dev = self.device();

		auto dims = self.inner().dimensions();

		// Dimensions inner to the scanned one (in memory order) form
		// non_scan_size; the remaining outer dimensions form panel_size.
		Index non_scan_size = 1;
		Index panel_size = 1;
		if (static_cast<int>(Self::Layout) == static_cast<int>(ColMajor)) {
			for (int i = 0; i < consume_dim; i++) {
				non_scan_size *= dims[i];
			}
			for (int i = consume_dim + 1; i < Self::NumDims; i++) {
				panel_size *= dims[i];
			}
		} else {
			for (int i = Self::NumDims - 1; i > consume_dim; i--) {
				non_scan_size *= dims[i];
			}
			for (int i = consume_dim - 1; i >= 0; i--) {
				panel_size *= dims[i];
			}
		}
		// When the scan dimension is strided, neighbouring non-scan elements
		// are contiguous; otherwise they are scan_size apart.
		const Index non_scan_stride = (scan_stride > 1) ? 1 : scan_size;
		auto eval_impl = self.inner();
		TensorSycl::internal::ScanLauncher_impl<CoeffReturnType, TensorSycl::internal::scan_step::first>::scan_block(
			eval_impl,
			data,
			accumulator,
			total_size,
			scan_size,
			panel_size,
			non_scan_size,
			scan_stride,
			non_scan_stride,
			inclusive,
			dev);
	}
};
} // namespace internal
} // namespace Eigen

#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP
